package deeplearn;

import captcharecognition.MulRecordDataLoader;
import captcharecognition.MultiRecordDataSetIterator;
import org.deeplearning4j.core.storage.StatsStorage;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.graph.StackVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;

/**
 * Notes (blog: https://blog.csdn.net/dong_lxkm/article/details/103125742):
 *
 *   1. DL4J has no Keras-style way to freeze the parameters of individual layers, so layers are
 *      "frozen" here by setting their learning rate to 0.
 *
 *   2. The updater must be SGD — adaptive updaters (e.g. Adam, RMSProp) fold gradients from
 *      earlier batches into the current update, so a zero learning rate would not fully stop
 *      parameter updates.
 *
 *   3. A StackVertex concatenates tensors along dimension 0, i.e. it merges real data samples
 *      with samples produced by the Generator so both jointly train the Discriminator.
 *
 *   4. During training the Discriminator is updated several times per iteration (to estimate the
 *      maximum distance) before the Generator is updated once.
 *
 *   5. 100,000 iterations are performed.
 *
 *   6. The data is downloaded to C:\Users\Administrator\.deeplearning4j\data\MNIST
 */
public class YzmGan {

	/** Learning rate for every trainable layer; a layer is "frozen" by setting its rate to 0. */
	static double lr = 0.01;
	/** Destination of the periodically saved model snapshot. */
	static String model = "F:/face/ganYzm.zip";

	/** Mini-batch size of both record iterators; the StackVertex doubles it for the discriminator. */
	private static final int BATCH_SIZE = 30;
	/** Generator layer names (frozen while the discriminator is trained). */
	private static final String[] GENERATOR_LAYERS = {"g1", "g2", "g3", "g4", "g5", "g6"};
	/** Discriminator layer names (frozen while the generator is trained). */
	private static final String[] DISCRIMINATOR_LAYERS =
			{"d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "out"};

	public static void main(String[] args) throws Exception {

		final NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
				// SGD only: adaptive updaters would defeat the learning-rate-0 freezing scheme.
				.updater(new Sgd(lr))
				.weightInit(WeightInit.XAVIER);

		final GraphBuilder graphBuilder = builder.graphBuilder().backpropType(BackpropType.Standard)
				.addInputs("input1", "input2")
				// FIX: input types must be declared so DL4J can infer nIn for every layer after
				// the first; without this the layers that omit nIn (g2..g6, d1..d10) fail to init.
				.setInputTypes(
						InputType.convolutional(MulRecordDataLoader.height, MulRecordDataLoader.width, MulRecordDataLoader.channels),
						InputType.convolutional(MulRecordDataLoader.height, MulRecordDataLoader.width, MulRecordDataLoader.channels))

				// ---- Generator: g1..g6 ----
				.addLayer("g1", new ConvolutionLayer.Builder(new int[]{5, 5}, new int[]{1, 1}, new int[]{0, 0})
						.nIn(MulRecordDataLoader.channels).nOut(48).activation(Activation.RELU).build(), "input1")
				.addLayer("g2", new ConvolutionLayer.Builder(new int[]{5, 5}, new int[]{1, 1}, new int[]{0, 0})
						.nOut(64).activation(Activation.RELU).build(), "g1")
				.addLayer("g3", new ConvolutionLayer.Builder(new int[]{3, 3}, new int[]{1, 1}, new int[]{0, 0})
						.nOut(128).activation(Activation.RELU).build(), "g2")
				.addLayer("g4", new ConvolutionLayer.Builder(new int[]{4, 4}, new int[]{1, 1}, new int[]{0, 0})
						.nOut(256).activation(Activation.RELU).build(), "g3")
				.addLayer("g5", new ConvolutionLayer.Builder(new int[]{4, 4}, new int[]{1, 1}, new int[]{0, 0})
						.nOut(512).activation(Activation.RELU).build(), "g4")
				.addLayer("g6", new ConvolutionLayer.Builder(new int[]{4, 4}, new int[]{1, 1}, new int[]{0, 0})
						.nOut(1024).activation(Activation.RELU).build(), "g5")

				// Stack real samples (input2) and generated samples (g6) along dimension 0 so the
				// discriminator sees a single, doubled batch. NOTE(review): StackVertex requires
				// identical shapes on all dims except 0 — confirm g6's output matches input2.
				.addVertex("stack", new StackVertex(), "input2", "g6")

				// ---- Discriminator: d1..d10 + out ----
				.addLayer("d1", new ConvolutionLayer.Builder(new int[]{3, 3}).hasBias(false)
						.nOut(256 * 4).activation(new ActivationLReLU(0.1)).build(), "stack")
				.addLayer("d2", new ConvolutionLayer.Builder(new int[]{3, 3}).hasBias(false)
						.nOut(128 * 4).activation(new ActivationLReLU(0.1)).build(), "d1")
				.addLayer("d3", new ConvolutionLayer.Builder(new int[]{3, 3}).hasBias(false)
						.nOut(64 * 4).activation(new ActivationLReLU(0.1)).build(), "d2")
				.addLayer("d4", new ConvolutionLayer.Builder(new int[]{3, 3}).hasBias(false)
						.nOut(64).activation(new ActivationLReLU(0.2)).build(), "d3")
				.addLayer("d5", new ConvolutionLayer.Builder(new int[]{3, 3}).hasBias(false)
						.nOut(64).activation(new ActivationLReLU(0.2)).build(), "d4")
				.addVertex("merge1", new MergeVertex(), "d5", "d3")
				.addLayer("d6", new ConvolutionLayer.Builder(new int[]{3, 3}).hasBias(false)
						.nOut(64).activation(new ActivationLReLU(0.2)).build(), "merge1")
				.addLayer("d7", new ConvolutionLayer.Builder(new int[]{3, 3}).hasBias(false)
						.nOut(64).activation(new ActivationLReLU(0.2)).build(), "d6")
				.addVertex("merge2", new MergeVertex(), "d7", "merge1")
				.addLayer("d8", new ConvolutionLayer.Builder(new int[]{5, 5}).hasBias(false)
						.nOut(1).activation(Activation.SIGMOID).build(), "merge2")
				.addLayer("d9", new ConvolutionLayer.Builder(new int[]{5, 5}).hasBias(false)
						.nOut(3).activation(Activation.TANH).build(), "d8")
				.addVertex("concatenate", new StackVertex(), "d8", "d9")
				.addLayer("d10", new ConvolutionLayer.Builder(new int[]{4, 4}).hasBias(false)
						.nOut(1).activation(Activation.SIGMOID).build(), "concatenate")
				// FIX: the original called setOutputs("out") without ever defining an "out"
				// vertex, so graphBuilder.build() threw. ComputationGraph.fit() also needs an
				// IOutputLayer with a loss function, so add a binary cross-entropy output head.
				.addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
						.nOut(1).activation(Activation.SIGMOID).build(), "d10")
				.setOutputs("out");

		ComputationGraph net = new ComputationGraph(graphBuilder.build());
		net.init();
		System.out.println(net.summary());

		// Live training dashboard.
		UIServer uiServer = UIServer.getInstance();
		StatsStorage statsStorage = new InMemoryStatsStorage();
		uiServer.attach(statsStorage);

		MultiDataSetIterator train = new MultiRecordDataSetIterator(BATCH_SIZE, "train");
		MultiDataSetIterator testMulIterator = new MultiRecordDataSetIterator(BATCH_SIZE, "yzmtest");

		net.setListeners(new ScoreIterationListener(10), new StatsListener(statsStorage),
				new EvaluativeListener(testMulIterator, 1, InvocationType.EPOCH_END));

		// Discriminator labels for the stacked (doubled) batch: 1 = real, 0 = generated.
		INDArray labelD = Nd4j.vstack(Nd4j.ones(BATCH_SIZE, 1), Nd4j.zeros(BATCH_SIZE, 1));
		// Generator labels: push the discriminator to call every stacked sample "real".
		INDArray labelG = Nd4j.ones(2 * BATCH_SIZE, 1);

		// NOTE(review): the original also drew Gaussian noise z each iteration but never fed it
		// into the graph — the generator currently consumes real features via "input1". The noise
		// should presumably be wired in as the generator input; confirm against the data loader.
		for (int i = 1; i <= 100000; i++) {
			if (!train.hasNext()) {
				train.reset(); // cycle over the training data indefinitely
			}
			INDArray[] trueExp = train.next().getFeatures();

			// Update the discriminator several times (to estimate the maximum distance) ...
			MultiDataSet dataSetD = new MultiDataSet(trueExp, new INDArray[]{labelD});
			for (int m = 0; m < 10; m++) {
				trainD(net, dataSetD);
			}
			// ... then update the generator once.
			MultiDataSet dataSetG = new MultiDataSet(trueExp, new INDArray[]{labelG});
			trainG(net, dataSetG);

			// Snapshot every 10k iterations.
			if (i % 10000 == 0) {
				net.save(new File(model), true);
			}
		}
	}

	/**
	 * Discriminator step D(x): freezes the generator (lr = 0) and trains only the
	 * discriminator layers on the given data set.
	 * FIX: the original omitted "d10", so the discriminator's last layer kept
	 * updating during generator training as well.
	 */
	public static void trainD(ComputationGraph net, MultiDataSet dataSet) {
		setLearningRates(net, GENERATOR_LAYERS, 0);
		setLearningRates(net, DISCRIMINATOR_LAYERS, lr);
		net.fit(dataSet);
	}

	/**
	 * Generator step G(z): freezes the discriminator (lr = 0) and trains only the
	 * generator layers on the given data set.
	 */
	public static void trainG(ComputationGraph net, MultiDataSet dataSet) {
		setLearningRates(net, GENERATOR_LAYERS, lr);
		setLearningRates(net, DISCRIMINATOR_LAYERS, 0);
		net.fit(dataSet);
	}

	/** Sets the learning rate of each named layer; rate 0 effectively freezes a layer under SGD. */
	private static void setLearningRates(ComputationGraph net, String[] layers, double rate) {
		for (String layer : layers) {
			net.setLearningRate(layer, rate);
		}
	}
}